import pandas as pd
import tkinter as tk
from tkinter import ttk, filedialog
import esa_overlap
from collections import defaultdict
import numpy as np
from checkboxtreeview import CheckboxTreeview

def main():
    """Open the Overlap Analysis Tool window and run the Tk event loop until it is closed."""
    OverlapGui().root.mainloop()

class OverlapGui:
    """Tkinter front end for configuring and running the ESA overlap analysis.

    Tabs:
        General              -- output file, chemical name, range vs. critical
                                habitat toggle, imputation-flag toggle, Submit.
        Individual Use-Sites -- one checkbox per CoA use-site.
        Grouped Use-Sites    -- checkbox tree of crop groups / subgroups.
        Geo Restrictions     -- whole-state exclusions plus an optional Excel
                                workbook of crop-specific county/state exclusions.

    Submit filters and groups the CoA acreage tables according to these
    selections and passes them to ``esa_overlap.generate_full_magtool_output``.
    """

    def __init__(self):
        # Get file paths to CoA use-site acreage inputs
        self.get_coa_file_paths()

        # Unique use-sites listed on the individual use-site tab
        self.crops = sorted(pd.read_csv(self.county_coa_file_path)["CONCAT USE SITE"].unique())
        # Filled with the chosen use-site names by compile_checked_crop_subgroups()
        self.selected_crops = set()

        # Initialize the GUI
        self.root = tk.Tk()
        self.root.title('Overlap Analysis Tool')

        # Create tabs for general settings, crops, crop groups, and geo exclusions
        self.initialize_tabs()

        # Populate each tab
        self.generate_general_tab()
        self.generate_crop_tab()
        self.generate_crop_group_tab()
        self.generate_geo_tab()

        # Pack the tabs
        self.tab_control.pack(expand=1, fill="both")

    def initialize_tabs(self):
        """Create the notebook and its four (still empty) tab frames."""
        self.tab_control = ttk.Notebook(self.root)

        self.general_tab = ttk.Frame(self.tab_control)
        self.crop_tab = ttk.Frame(self.tab_control)
        self.crop_group_tab = ttk.Frame(self.tab_control)
        self.geo_tab = ttk.Frame(self.tab_control)

        self.tab_control.add(self.general_tab, text="General")
        self.tab_control.add(self.crop_tab, text="Individual Use-Sites")
        self.tab_control.add(self.crop_group_tab, text="Grouped Use-Sites")
        self.tab_control.add(self.geo_tab, text="Geo Restrictions")

    def get_coa_file_paths(self):
        """Set the paths of the county/state/national CoA acreage CSVs."""
        self.county_coa_file_path = r"input files/coa_county_acres_updated_2022_01_21.csv"
        self.state_coa_file_path = r"input files/coa_state_acres_updated_2022_01_21.csv"
        self.national_coa_file_path = r"input files/coa_national_acres_updated_2022_01_21.csv"

    def get_crop_group_file_paths(self):
        """Set the paths of the tab-separated crop group / subgroup / crosswalk tables."""
        self.crop_group_list_file_path = r"input files/crop_group_list.txt"
        self.crop_subgroup_list_file_path = r"input files/crop_subgroup_list.txt"
        self.use_site_crop_group_crosswalk_file_path = r"input files/use_site_crop_group_crosswalk.txt"

    def get_species_file_paths(self):
        """Set the paths of the species range, critical habitat, and master list CSVs."""
        self.species_range_file_path = r"input files/compiled_species_range.csv"
        self.species_ch_file_path = r"input files/compiled_species_crit_habitat.csv"
        self.master_species_file_path = r"input files/MasterListESA_Oct2022_20221027.csv"

    def get_row_crop_list_file_path(self):
        """Set the path of the newline-separated row crop list."""
        self.row_crop_list_file_path = r"input files/row_crop_list.txt"

    def generate_general_tab(self):
        """Populate the General tab: output location, chemical name, dataset toggle, flags, Submit."""
        # Create file selection box
        self.generate_output_file_location_selection()

        # Create a box for chemical name
        chem_selection_label = ttk.Label(self.general_tab, text="Enter Chemical Name:")
        chem_selection_label.pack(fill='x', padx=5, pady=5)
        self.chem_name = tk.StringVar()
        chem_selection = tk.Entry(self.general_tab, textvariable=self.chem_name)
        chem_selection.pack(fill='x', padx=5, pady=5)

        # Create range vs crit habitat toggle
        self.generate_ch_vs_range_selection()

        # Create a checkbox for toggling the imputation flags
        self.include_flags = tk.BooleanVar()
        self.include_flags.set(False)
        flag_selection_button = tk.Checkbutton(self.general_tab,
                                               text="Include crop area imputation flags in output",
                                               variable=self.include_flags,
                                               onvalue=True,
                                               offvalue=False)
        flag_selection_button.pack(fill='x', padx=5, pady=5)

        # Create the submit button
        submit_button = tk.Button(self.general_tab, text="Submit", command=self.run_overlap_analysis)
        submit_button.pack(fill='x', padx=5, pady=5)

    def generate_output_file_location_selection(self):
        """Add the output-file label, entry (bound to self.output_file_location) and browse button."""
        label_crop = ttk.Label(self.general_tab, text="Select output file location:")
        label_crop.pack(fill='x', padx=5, pady=5)

        # Holds the chosen output path
        self.output_file_location = tk.StringVar()

        output = tk.Entry(self.general_tab, textvariable=self.output_file_location)
        output.pack(fill='x', padx=5, pady=5)

        submit_button = tk.Button(self.general_tab, text="Select Location",
                                  command=self.select_output_file_location)
        submit_button.pack(fill='x', padx=5, pady=5)

    def generate_ch_vs_range_selection(self):
        """Add the Range / Critical Habitat radio buttons.

        self.is_range is a StringVar holding "1" for range and "0" for
        critical habitat; run_overlap_analysis() compares it against "1".
        """
        self.is_range = tk.StringVar(value=1)

        ch_vs_range_label = ttk.Label(self.general_tab, text="Select species dataset:")
        ch_vs_range_label.pack(fill='x', padx=5, pady=5)

        range_button = tk.Radiobutton(self.general_tab, text="Range", variable=self.is_range, value=1)
        range_button.pack(fill='x', padx=5, pady=5)

        ch_button = tk.Radiobutton(self.general_tab, text="Critical Habitat", variable=self.is_range, value=0)
        ch_button.pack(fill='x', padx=5, pady=5)

    def _create_scrollable_frame(self, parent, on_configure, tag):
        """Build a white, vertically scrollable frame inside *parent*.

        Tk frames cannot scroll themselves, so the frame is embedded in a
        canvas; *on_configure* is bound to the frame's <Configure> event and
        is expected to refresh the canvas scroll region.

        Returns:
            (canvas, frame, scrollbar)
        """
        canvas = tk.Canvas(parent, borderwidth=0, background="#ffffff")
        frame = tk.Frame(canvas, background="#ffffff")
        scrollbar = tk.Scrollbar(parent, orient="vertical", command=canvas.yview)
        canvas.configure(yscrollcommand=scrollbar.set)
        scrollbar.pack(side="right", fill="y")
        canvas.pack(side="left", fill="both", expand=True)
        canvas.create_window((4, 4), window=frame, anchor="nw", tags=tag)
        frame.bind("<Configure>", on_configure)
        return canvas, frame, scrollbar

    def generate_crop_tab(self):
        """Populate the individual use-site tab with a scrollable list of checkboxes."""
        label_crop = ttk.Label(self.crop_tab, text="Select crops to be included in the overlap analysis:")
        label_crop.pack(fill='x', padx=5, pady=5)

        self.crop_canvas, self.crop_canvas_frame, self.crop_canvas_frame_sb = self._create_scrollable_frame(
            self.crop_tab, self.onFrameConfigure, "self.crop_canvas_frame")
        self.generate_crop_checkboxes()

    def onFrameConfigure(self, event):
        """Resize the crop canvas scroll region to fit its contents."""
        self.crop_canvas.configure(scrollregion=self.crop_canvas.bbox("all"))

    def select_output_file_location(self):
        """Ask for an output path and store it in self.output_file_location."""
        # NOTE(review): defaultextension is '.csv' while the filter offers '.xlsx';
        # confirm which format esa_overlap.generate_full_magtool_output writes.
        file_dialog_options = dict(defaultextension='.csv',
                                   filetypes=[('Excel file', '*.xlsx'), ('All files', '*.*')])
        self.output_file_location.set(
            filedialog.asksaveasfilename(parent=self.general_tab, **file_dialog_options))

    def generate_crop_checkboxes(self):
        """Create one checkbox per use-site in self.crops.

        self.cb_vars holds [use-site, BooleanVar] pairs that
        compile_checked_crop_subgroups() reads.
        """
        self.crop_checks = []
        self.cb_vars = []

        for crop in self.crops:
            current_cb_var = tk.BooleanVar()
            current_cb_var.set(False)
            new_check = ttk.Checkbutton(self.crop_canvas_frame, text=crop, variable=current_cb_var)
            self.crop_checks.append([crop, new_check])
            self.cb_vars.append([crop, current_cb_var])
            new_check.pack(fill='x', padx=5, pady=5)

    def generate_geo_tab(self):
        """Populate the geo restriction tab: restriction workbook picker and state checkboxes."""
        self.import_states()

        # Workbook with crop-specific county/state exclusions
        self.generate_use_site_restriction_file_location_selection()

        label_geo_state = ttk.Label(
            self.geo_tab, text=r"Select states/territories to be excluded from the overlap analysis:")
        label_geo_state.pack(fill='x', padx=5, pady=5)

        self.geo_canvas, self.geo_canvas_frame, self.geo_canvas_frame_sb = self._create_scrollable_frame(
            self.geo_tab, self.onFrameConfigureGeo, "self.geo_canvas_frame")

        self.generate_state_checkboxes()

    def onFrameConfigureGeo(self, event):
        """Resize the geo canvas scroll region to fit its contents."""
        self.geo_canvas.configure(scrollregion=self.geo_canvas.bbox("all"))

    def import_states(self):
        """Read the sorted unique STATE_NAME values from the county CoA file."""
        self.unique_states = sorted(
            pd.read_csv(self.county_coa_file_path, usecols=["STATE_NAME"])["STATE_NAME"].unique())

    def assign_row_crop(self, coa_county_crop_df):
        """Map each use-site in *coa_county_crop_df* to whether it is on the row crop list.

        Returns:
            defaultdict(bool) keyed by "CONCAT USE SITE" value (inserted in
            sorted order); unknown keys read as False.
        """
        self.get_row_crop_list_file_path()

        # One row crop name per line
        with open(self.row_crop_list_file_path) as f:
            row_crops = set(f.read().splitlines())

        use_sites = sorted(coa_county_crop_df['CONCAT USE SITE'].unique())
        return defaultdict(bool, {site: site in row_crops for site in use_sites})

    def generate_state_checkboxes(self):
        """Create one exclusion checkbox per state.

        self.state_vars holds [state, BooleanVar] pairs read by
        apply_full_state_restrictions(); the widgets go in self.state_check.
        """
        self.state_check = []
        self.state_vars = []

        for state in self.unique_states:
            current_state_var = tk.BooleanVar()
            new_check = ttk.Checkbutton(self.geo_canvas_frame, text=state, variable=current_state_var)
            self.state_check.append([state, new_check])
            self.state_vars.append([state, current_state_var])
            new_check.pack(fill='x', padx=5, pady=5)

    def generate_use_site_restriction_file_location_selection(self):
        """Add the label, entry (self.geo_restriction_file_location) and browse button for the restriction workbook."""
        label_geo_import = ttk.Label(
            self.geo_tab, text="Select file location for crop-specific county- and state-level exclusions:")
        label_geo_import.pack(fill='x', padx=5, pady=5)

        # Empty string means "no crop-specific restrictions"
        self.geo_restriction_file_location = tk.StringVar()

        output = tk.Entry(self.geo_tab, textvariable=self.geo_restriction_file_location)
        output.pack(fill='x', padx=5, pady=5)

        submit_button = tk.Button(self.geo_tab, text="Select Location",
                                  command=self.select_restriction_file_location)
        submit_button.pack(fill='x', padx=5, pady=5)

    def select_restriction_file_location(self):
        """Ask for the restriction workbook and store its path."""
        # The workbook is read with pd.read_excel, so default to .xlsx
        file_dialog_options = dict(defaultextension='.xlsx',
                                   filetypes=[('XLSX file', '*.xlsx'), ('All files', '*.*')])
        self.geo_restriction_file_location.set(
            filedialog.askopenfilename(parent=self.geo_tab, **file_dialog_options))

    def run_overlap_analysis(self):
        """Handle Submit: build the analysis inputs and write the output.

        Collects the checked use-sites/groups, loads the CoA county/state/
        national tables and the selected species dataset, keeps only the
        selected use-sites, applies geo restrictions to the county table,
        aggregates crop groups, flags row crops, and calls
        ``esa_overlap.generate_full_magtool_output``.  Shows an error dialog
        and returns early if nothing is selected or no output path is set.
        """
        from tkinter import messagebox  # only needed for input validation

        # Get the locations of files related to species
        self.get_species_file_paths()

        # Compile crop groups and get the set of selected crops
        self.compile_checked_crop_subgroups()

        # Without these checks the run fails later with a traceback
        # (pd.concat of nothing, or presumably writing to an empty path).
        if not self.compiled_crop_group_dict:
            messagebox.showerror("Overlap Analysis Tool",
                                 "Select at least one use-site or use-site group.",
                                 parent=self.root)
            return
        if not self.output_file_location.get():
            messagebox.showerror("Overlap Analysis Tool",
                                 "Select an output file location.",
                                 parent=self.root)
            return

        # Import the input COA files
        coa_county_df = esa_overlap.import_coa_county_df(self.county_coa_file_path)
        coa_state_df = esa_overlap.import_coa_state_df(self.state_coa_file_path)
        coa_national_df = esa_overlap.import_coa_national_df(self.national_coa_file_path)

        # Import the species file ("1" = range, otherwise critical habitat)
        if self.is_range.get() == "1":
            species_df = esa_overlap.import_species_df(self.species_range_file_path)
        else:
            species_df = esa_overlap.import_species_df(self.species_ch_file_path)

        # Filter COA dataframes for selected crops
        coa_county_df = coa_county_df[coa_county_df['CONCAT USE SITE'].isin(self.selected_crops)]
        coa_state_df = coa_state_df[coa_state_df['CONCAT USE SITE'].isin(self.selected_crops)]
        coa_national_df = coa_national_df[coa_national_df['CONCAT USE SITE'].isin(self.selected_crops)]

        # Apply geo-restrictions (county table only)
        coa_county_df = self.get_geo_restricted_county_df(coa_county_df)

        # Group crops for analysis
        coa_county_df, coa_state_df, coa_national_df = self.format_coa_tables_for_grouped_output(
            coa_county_df, coa_state_df, coa_national_df)

        # Determine whether each included crop/group represents a row crop (for buffering)
        is_row_crop = self.assign_row_crop(coa_county_df)

        # Generate the output
        esa_overlap.generate_full_magtool_output(coa_county_df, coa_state_df, coa_national_df, species_df,
                                                 self.output_file_location.get(), self.chem_name.get(),
                                                 self.master_species_file_path, self.include_flags.get(),
                                                 is_row_crop)

    def get_geo_restricted_county_df(self, county_df):
        """Apply checked whole-state exclusions, then any restrictions from the workbook."""
        county_df = self.apply_full_state_restrictions(county_df)

        # The detailed restriction workbook is optional
        if self.geo_restriction_file_location.get():
            full_county_restrictions_list, specific_county_restrictions_dict = self.extract_county_restrictions_csv()
            full_state_restrictions_list, specific_state_restrictions_dict = self.extract_state_restrictions_csv()

            if full_county_restrictions_list or specific_county_restrictions_dict:
                county_df = self.apply_county_restrictions(
                    county_df, full_county_restrictions_list, specific_county_restrictions_dict)

            if full_state_restrictions_list or specific_state_restrictions_dict:
                county_df = self.apply_state_restrictions(
                    county_df, full_state_restrictions_list, specific_state_restrictions_dict)

        return county_df

    def apply_full_state_restrictions(self, county_df):
        """Drop rows whose STATE_NAME is checked on the geo restriction tab."""
        excluded_states = {state for state, var in self.state_vars if var.get()}
        return county_df[~county_df.STATE_NAME.isin(excluded_states)]

    def _extract_restrictions(self, sheet_name, location_column):
        """Read one sheet of the geo-restriction workbook.

        Each row names a location in *location_column* and use-sites in the
        columns whose header contains "EXCLUDED USE-SITE".  "ALL"
        (case-insensitive) in the first such column excludes the location for
        every use-site.

        Returns:
            (full_restrictions, specific_restrictions): a set of locations
            excluded for all use-sites, and a defaultdict(set) mapping
            use-site -> locations excluded for that use-site only.
        """
        restrictions_df = pd.read_excel(io=self.geo_restriction_file_location.get(), sheet_name=sheet_name)
        full_restrictions = set()
        specific_restrictions = defaultdict(set)

        crop_columns = [col for col in restrictions_df.columns
                        if isinstance(col, str) and "EXCLUDED USE-SITE" in col]
        if not crop_columns:
            return full_restrictions, specific_restrictions

        for _, row in restrictions_df.iterrows():
            location = row[location_column]
            first_crop = row[crop_columns[0]]
            # Blank cells are NaN (a float), so only a string can mean "ALL"
            if isinstance(first_crop, str) and first_crop.strip().casefold() == "all":
                full_restrictions.add(location)
                continue
            for crop in row[crop_columns].dropna():
                specific_restrictions[crop].add(location)

        return full_restrictions, specific_restrictions

    def extract_county_restrictions_csv(self):
        """Read the "County Exclusions" sheet (keyed by GEOID) of the restriction workbook."""
        return self._extract_restrictions("County Exclusions", "GEOID")

    def extract_state_restrictions_csv(self):
        """Read the "State Exclusions" sheet (keyed by STATE) of the restriction workbook."""
        return self._extract_restrictions("State Exclusions", "STATE")

    @staticmethod
    def _apply_restrictions(county_df, location_column, full_restrictions, specific_restrictions):
        """Drop rows of *county_df* excluded by a set of geo restrictions.

        Args:
            county_df: county-level CoA rows.
            location_column: column compared against the restriction locations.
            full_restrictions: locations removed for every use-site.
            specific_restrictions: use-site -> locations removed for that use-site only.
        """
        if full_restrictions:
            county_df = county_df[~county_df[location_column].isin(full_restrictions)]
        if specific_restrictions:
            excluded = pd.Series(False, index=county_df.index)
            for use_site, locations in specific_restrictions.items():
                excluded |= ((county_df["CONCAT USE SITE"] == use_site)
                             & county_df[location_column].isin(locations))
            county_df = county_df[~excluded]
        return county_df

    def apply_county_restrictions(self, county_df, full_county_restrictions_list, specific_county_restrictions_dict):
        """Remove rows by GEOID, either for all use-sites or per use-site."""
        return self._apply_restrictions(
            county_df, "GEOID", full_county_restrictions_list, specific_county_restrictions_dict)

    def apply_state_restrictions(self, county_df, full_state_restrictions_list, specific_state_restrictions_dict):
        """Remove rows by STATE_NAME, either for all use-sites or per use-site."""
        return self._apply_restrictions(
            county_df, "STATE_NAME", full_state_restrictions_list, specific_state_restrictions_dict)

    def generate_crop_group_tab(self):
        """Populate the grouped use-site tab."""
        self.import_crop_groups()
        self.generate_crop_group_checkboxes()

    def import_crop_groups(self):
        """Load crop group names and the use-site -> group/subgroup crosswalk.

        Sets self.crop_group_dict[group][subgroup] -> list of use-sites.  A
        "Group/Subgroup" code that is itself a known group id goes under the
        subgroup key "_"; otherwise its last character is the subgroup and
        the rest is the group id.
        """
        self.get_crop_group_file_paths()
        crop_group_df = pd.read_csv(self.crop_group_list_file_path, sep="\t")
        crop_subgroup_df = pd.read_csv(self.crop_subgroup_list_file_path, sep="\t")
        use_site_to_crop_group_df = pd.read_csv(self.use_site_crop_group_crosswalk_file_path, sep="\t")

        # Display names and preferred output names for groups / subgroups
        self.crop_group_names = crop_group_df.set_index("Group ID")["Group Name"].to_dict()
        self.crop_subgroup_names = crop_subgroup_df.set_index(["Group ID", "Subgroup ID"])["Subgroup Name"].to_dict()

        self.output_crop_group_names = crop_group_df.set_index("Group ID")["Final Label Name"].to_dict()
        self.output_crop_subgroup_names = crop_subgroup_df.set_index(
            ["Group ID", "Subgroup ID"])["Final Label Name"].to_dict()

        self.crop_group_dict = defaultdict(lambda: defaultdict(list))
        for code, use_site in zip(use_site_to_crop_group_df["Group/Subgroup"],
                                  use_site_to_crop_group_df["CONCAT USE SITE"]):
            if code in self.crop_group_names:
                self.crop_group_dict[code]["_"].append(use_site)
            else:
                self.crop_group_dict[code[:-1]][code[-1]].append(use_site)

    def generate_crop_group_checkboxes(self):
        """Build the group > subgroup > use-site checkbox tree.

        Item ids are "group", "group+subgroup" and "group+subgroup+crop";
        compile_checked_crop_subgroups() parses the last form.
        """
        self.tree = CheckboxTreeview(self.crop_group_tab)
        self.tree.pack(expand=1, fill=tk.BOTH)

        # Group ids are dash-separated integers; sort numerically on each part
        for group in sorted(self.crop_group_dict, key=lambda x: list(map(int, x.split("-")))):
            group_label = group + ". " + self.crop_group_names[group]
            self.tree.insert("", "end", group, text=group_label)
            for subgroup in sorted(self.crop_group_dict[group]):
                subgroup_id = "+".join([group, subgroup])
                # Crops assigned to a group without a subgroup are shown under the group name
                if subgroup == "_":
                    subgroup_label = group_label
                else:
                    subgroup_label = group + subgroup + ". " + self.crop_subgroup_names[(group, subgroup)]
                self.tree.insert(group, "end", subgroup_id, text=subgroup_label)
                for crop in sorted(self.crop_group_dict[group][subgroup]):
                    self.tree.insert(subgroup_id, "end", "+".join([group, subgroup, crop]), text=crop)

    def compile_checked_crop_subgroups(self):
        """Collect the checked use-sites and crop groups.

        Sets:
            self.selected_crops: every use-site used by the analysis.
            self.compiled_crop_group_dict: use-site name -> {use-site} for
                individually checked use-sites, and (group, subgroup) -> set of
                use-sites for tree selections.  A use-site is only counted
                once; individual selections take precedence over groups.
        """
        self.compiled_crop_group_dict = defaultdict(set)
        self.selected_crops = {crop for crop, var in self.cb_vars if var.get()}

        for crop in self.selected_crops:
            self.compiled_crop_group_dict[crop] = {crop}

        # Checked tree items are expected to be leaf ids "group+subgroup+crop".
        # maxsplit=2 keeps use-site names that themselves contain "+" intact.
        for item_id in self.tree.get_checked():
            group, subgroup, crop = item_id.split("+", 2)
            if crop not in self.selected_crops:
                self.compiled_crop_group_dict[group, subgroup].add(crop)
                self.selected_crops.add(crop)

    def _output_group_name(self, group, subgroup):
        """Return the output label for a (group, subgroup) key ("_" = whole group)."""
        if subgroup == "_":
            return self.output_crop_group_names[group]
        return self.output_crop_subgroup_names[(group, subgroup)]

    @staticmethod
    def _sum_by_location(df, location_column, value_column, group_name):
        """Collapse *df* to one row per location, holding the summed acreage.

        Rows are sorted by "Imputation" first, so the row kept for each
        location is the one with the lowest Imputation value (intended to be
        the more restrictive imputation).  Its use-site is renamed to
        *group_name*.
        """
        df = df.sort_values(by="Imputation")
        df[value_column] = df.groupby(by=location_column)[value_column].transform("sum")
        return df.drop_duplicates(subset=[location_column]).assign(**{"CONCAT USE SITE": group_name})

    def format_coa_tables_for_grouped_output(self, coa_county_df, coa_state_df, coa_national_df):
        """Build the county/state/national tables for the selected use-sites and groups.

        Individually selected use-sites are passed through unchanged.  Each
        (group, subgroup) selection is aggregated into one row per GEOID
        (county), per Location (state), and one national row, labelled with
        the group's output name.

        Returns:
            (county_df, state_df, national_df).  These are empty frames with
            the input columns if nothing is selected.
        """
        coa_county_out, coa_state_out, coa_national_out = [], [], []

        for key, crops in self.compiled_crop_group_dict.items():
            current_county_df = coa_county_df[coa_county_df["CONCAT USE SITE"].isin(crops)].copy()
            current_state_df = coa_state_df[coa_state_df["CONCAT USE SITE"].isin(crops)].copy()
            current_national_df = coa_national_df[coa_national_df["CONCAT USE SITE"].isin(crops)].copy()

            if not isinstance(key, str):
                # key is a (group, subgroup) tuple: combine its crops
                group_name = self._output_group_name(*key)
                current_county_df = self._sum_by_location(current_county_df, "GEOID", "VALUE", group_name)
                current_state_df = self._sum_by_location(current_state_df, "Location", "Value", group_name)
                national_total = current_national_df["VALUE"].sum()
                current_national_df = current_national_df.head(1).assign(
                    **{"VALUE": national_total, "CONCAT USE SITE": group_name})

            coa_county_out.append(current_county_df)
            coa_state_out.append(current_state_df)
            coa_national_out.append(current_national_df)

        if not coa_county_out:
            # pd.concat([]) raises; return empty tables with the input columns
            return coa_county_df.iloc[0:0], coa_state_df.iloc[0:0], coa_national_df.iloc[0:0]
        return pd.concat(coa_county_out), pd.concat(coa_state_out), pd.concat(coa_national_out)

    
# Launch the GUI only when run as a script, not when imported
if __name__ == "__main__":
    main()